{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7章　ブロック崩しBreakoutの再生プログラム    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# パッケージのimport\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import gym\n",
    "from gym import spaces\n",
    "from gym.spaces.box import Box\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# パッケージのインポート\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# 動画の描画関数の宣言\n",
    "# 参考URL http://nbviewer.jupyter.org/github/patrickmineault\n",
    "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n",
    "from JSAnimation.IPython_display import display_animation\n",
    "from matplotlib import animation\n",
    "from IPython.display import display\n",
    "\n",
    "def display_frames_as_gif(frames):\n",
    "    \"\"\"\n",
    "    Displays a list of frames as a gif, with controls\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(frames[0].shape[1]/72.0*1, frames[0].shape[0]/72.0*1),\n",
    "               dpi=72)\n",
    "    patch = plt.imshow(frames[0])\n",
    "    plt.axis('off')\n",
    " \n",
    "    def animate(i):\n",
    "        patch.set_data(frames[i])\n",
    " \n",
    "    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n",
    "                                   interval=20)\n",
    " \n",
    "    anim.save('breakout.mp4')  # 動画のファイル名と保存です\n",
    "    display(display_animation(anim, default_mode='loop'))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 実行環境の設定\n",
    "# 参考：https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py\n",
    "\n",
    "import cv2\n",
    "cv2.ocl.setUseOpenCL(False)\n",
    "\n",
    "\n",
    "class NoopResetEnv(gym.Wrapper):\n",
    "    def __init__(self, env, noop_max=30):\n",
    "        '''工夫1のNo-Operationです。リセット後適当なステップの間何もしないようにし、\n",
    "        ゲーム開始の初期状態を様々にすることｆで、特定の開始状態のみで学習するのを防ぐ'''\n",
    "\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "        self.noop_max = noop_max\n",
    "        self.override_num_noops = None\n",
    "        self.noop_action = 0\n",
    "        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        \"\"\" Do no-op action for a number of steps in [1, noop_max].\"\"\"\n",
    "        self.env.reset(**kwargs)\n",
    "        if self.override_num_noops is not None:\n",
    "            noops = self.override_num_noops\n",
    "        else:\n",
    "            noops = self.unwrapped.np_random.randint(\n",
    "                1, self.noop_max + 1)  # pylint: disable=E1101\n",
    "        assert noops > 0\n",
    "        obs = None\n",
    "        for _ in range(noops):\n",
    "            obs, _, done, _ = self.env.step(self.noop_action)\n",
    "            if done:\n",
    "                obs = self.env.reset(**kwargs)\n",
    "        return obs\n",
    "\n",
    "    def step(self, ac):\n",
    "        return self.env.step(ac)\n",
    "\n",
    "\n",
    "class EpisodicLifeEnv(gym.Wrapper):\n",
    "    def __init__(self, env):\n",
    "        '''工夫2のEpisodic Lifeです。1機失敗したときにリセットし、失敗時の状態から次を始める'''\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "        self.lives = 0\n",
    "        self.was_real_done = True\n",
    "\n",
    "    def step(self, action):\n",
    "        obs, reward, done, info = self.env.step(action)\n",
    "        self.was_real_done = done\n",
    "        # check current lives, make loss of life terminal,\n",
    "        # then update lives to handle bonus lives\n",
    "        lives = self.env.unwrapped.ale.lives()\n",
    "        if lives < self.lives and lives > 0:\n",
    "            # for Qbert sometimes we stay in lives == 0 condtion for a few frames\n",
    "            # so its important to keep lives > 0, so that we only reset once\n",
    "            # the environment advertises done.\n",
    "            done = True\n",
    "        self.lives = lives\n",
    "        return obs, reward, done, info\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        '''5機とも失敗したら、本当にリセット'''\n",
    "        if self.was_real_done:\n",
    "            obs = self.env.reset(**kwargs)\n",
    "        else:\n",
    "            # no-op step to advance from terminal/lost life state\n",
    "            obs, _, _, _ = self.env.step(0)\n",
    "        self.lives = self.env.unwrapped.ale.lives()\n",
    "        return obs\n",
    "\n",
    "\n",
    "class MaxAndSkipEnv(gym.Wrapper):\n",
    "    def __init__(self, env, skip=4):\n",
    "        '''工夫3のMax and Skipです。4フレーム連続で同じ行動を実施し、最後の3、4フレームの最大値をとった画像をobsにする'''\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "        # most recent raw observations (for max pooling across time steps)\n",
    "        self._obs_buffer = np.zeros(\n",
    "            (2,)+env.observation_space.shape, dtype=np.uint8)\n",
    "        self._skip = skip\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n",
    "        total_reward = 0.0\n",
    "        done = None\n",
    "        for i in range(self._skip):\n",
    "            obs, reward, done, info = self.env.step(action)\n",
    "            if i == self._skip - 2:\n",
    "                self._obs_buffer[0] = obs\n",
    "            if i == self._skip - 1:\n",
    "                self._obs_buffer[1] = obs\n",
    "            total_reward += reward\n",
    "            if done:\n",
    "                break\n",
    "        # Note that the observation on the done=True frame\n",
    "        # doesn't matter\n",
    "        max_frame = self._obs_buffer.max(axis=0)\n",
    "\n",
    "        return max_frame, total_reward, done, info\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        return self.env.reset(**kwargs)\n",
    "\n",
    "\n",
    "class WarpFrame(gym.ObservationWrapper):\n",
    "    def __init__(self, env):\n",
    "        '''工夫4のWarp frameです。画像サイズをNatureのDQN論文と同じ84x84の白黒にします'''\n",
    "        gym.ObservationWrapper.__init__(self, env)\n",
    "        self.width = 84\n",
    "        self.height = 84\n",
    "        self.observation_space = spaces.Box(low=0, high=255,\n",
    "                                            shape=(self.height, self.width, 1), dtype=np.uint8)\n",
    "\n",
    "    def observation(self, frame):\n",
    "        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)\n",
    "        frame = cv2.resize(frame, (self.width, self.height),\n",
    "                           interpolation=cv2.INTER_AREA)\n",
    "        return frame[:, :, None]\n",
    "\n",
    "\n",
    "class WrapPyTorch(gym.ObservationWrapper):\n",
    "    def __init__(self, env=None):\n",
    "        '''PyTorchのミニバッチのインデックス順に変更するラッパー'''\n",
    "        super(WrapPyTorch, self).__init__(env)\n",
    "        obs_shape = self.observation_space.shape\n",
    "        self.observation_space = Box(\n",
    "            self.observation_space.low[0, 0, 0],\n",
    "            self.observation_space.high[0, 0, 0],\n",
    "            [obs_shape[2], obs_shape[1], obs_shape[0]],\n",
    "            dtype=self.observation_space.dtype)\n",
    "\n",
    "    def observation(self, observation):\n",
    "        return observation.transpose(2, 0, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 再生用の実行環境\n",
    "\n",
    "\n",
    "class EpisodicLifeEnvPlay(gym.Wrapper):\n",
    "    def __init__(self, env):\n",
    "        '''工夫2のEpisodic Lifeです。1機失敗したときにリセットし、失敗時の状態から次を始める。\n",
    "        今回は再生用に、1機失敗したときのリセット時もブロックの状態をリセットします'''\n",
    "\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "\n",
    "    def step(self, action):\n",
    "        obs, reward, done, info = self.env.step(action)\n",
    "        # ライフ（残機）が始め5あるが、1つでも減ると終了にする\n",
    "        if self.env.unwrapped.ale.lives() < 5:\n",
    "            done = True\n",
    "\n",
    "        return obs, reward, done, info\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        '''1回でも失敗したら完全リセット'''\n",
    "\n",
    "        obs = self.env.reset(**kwargs)\n",
    "\n",
    "        return obs\n",
    "\n",
    "\n",
    "class MaxAndSkipEnvPlay(gym.Wrapper):\n",
    "    def __init__(self, env, skip=4):\n",
    "        '''工夫3のMax and Skipです。4フレーム連続で同じ行動を実施し、最後の4フレームの画像をobsにする'''\n",
    "        gym.Wrapper.__init__(self, env)\n",
    "        # most recent raw observations (for max pooling across time steps)\n",
    "        self._obs_buffer = np.zeros(\n",
    "            (2,)+env.observation_space.shape, dtype=np.uint8)\n",
    "        self._skip = skip\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n",
    "        total_reward = 0.0\n",
    "        done = None\n",
    "        for i in range(self._skip):\n",
    "            obs, reward, done, info = self.env.step(action)\n",
    "            if i == self._skip - 2:\n",
    "                self._obs_buffer[0] = obs\n",
    "            if i == self._skip - 1:\n",
    "                self._obs_buffer[1] = obs\n",
    "            total_reward += reward\n",
    "            if done:\n",
    "                break\n",
    "\n",
    "        return obs, total_reward, done, info\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        return self.env.reset(**kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 実行環境生成関数の定義\n",
    "\n",
    "# 並列実行環境\n",
    "from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv\n",
    "\n",
    "\n",
    "def make_env(env_id, seed, rank):\n",
    "    def _thunk():\n",
    "        '''_thunk()がマルチプロセス環境のSubprocVecEnvを実行するのに必要'''\n",
    "\n",
    "        env = gym.make(env_id)\n",
    "        #env = NoopResetEnv(env, noop_max=30)\n",
    "        env = MaxAndSkipEnv(env, skip=4)\n",
    "        env.seed(seed + rank)  # 乱数シードの設定\n",
    "        #env = EpisodicLifeEnv(env)\n",
    "        env = EpisodicLifeEnvPlay(env)\n",
    "        env = WarpFrame(env)\n",
    "        env = WrapPyTorch(env)\n",
    "\n",
    "        return env\n",
    "\n",
    "    return _thunk\n",
    "\n",
    "\n",
    "def make_env_play(env_id, seed, rank):\n",
    "    '''再生用の実行環境'''\n",
    "    env = gym.make(env_id)\n",
    "    #env = NoopResetEnv(env, noop_max=30)\n",
    "    #env = MaxAndSkipEnv(env, skip=4)\n",
    "    env = MaxAndSkipEnvPlay(env, skip=4)\n",
    "    env.seed(seed + rank)  # 乱数シードの設定\n",
    "    env = EpisodicLifeEnvPlay(env)\n",
    "    #env = WarpFrame(env)\n",
    "    #env = WrapPyTorch(env)\n",
    "\n",
    "    return env\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定数の設定\n",
    "\n",
    "ENV_NAME = 'BreakoutNoFrameskip-v4' \n",
    "# Breakout-v0ではなく、BreakoutNoFrameskip-v4を使用\n",
    "# v0はフレームが自動的に2-4のランダムにskipされますが、今回はフレームスキップはさせないバージョンを使用\n",
    "# 参考URL https://becominghuman.ai/lets-build-an-atari-ai-part-1-dqn-df57e8ff3b26\n",
    "# https://github.com/openai/gym/blob/5cb12296274020db9bb6378ce54276b31e7002da/gym/envs/__init__.py#L371\n",
    "    \n",
    "NUM_SKIP_FRAME = 4 # skipするframe数です\n",
    "NUM_STACK_FRAME = 4  # 状態として連続的に保持するframe数です\n",
    "NOOP_MAX = 30  #  reset時に何もしないフレームを挟む（No-operation）フレーム数の乱数上限です\n",
    "NUM_PROCESSES = 16 #  並列して同時実行するプロセス数です\n",
    "NUM_ADVANCED_STEP = 5  # 何ステップ進めて報酬和を計算するのか設定\n",
    "GAMMA = 0.99  # 時間割引率\n",
    "\n",
    "TOTAL_FRAMES=10e6  #  学習に使用する総フレーム数\n",
    "NUM_UPDATES = int(TOTAL_FRAMES / NUM_ADVANCED_STEP / NUM_PROCESSES)  # ネットワークの総更新回数\n",
    "# NUM_UPDATESは125,000となる\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A2Cの損失関数の計算のための定数設定\n",
    "value_loss_coef = 0.5\n",
    "entropy_coef = 0.01\n",
    "max_grad_norm = 0.5\n",
    "\n",
    "# 学習手法RMSpropの設定\n",
    "lr = 7e-4\n",
    "eps = 1e-5\n",
    "alpha = 0.99\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    }
   ],
   "source": [
    "# GPUの使用の設定\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "print(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# メモリオブジェクトの定義\n",
    "\n",
    "\n",
    "class RolloutStorage(object):\n",
    "    '''Advantage学習するためのメモリクラスです'''\n",
    "\n",
    "    def __init__(self, num_steps, num_processes, obs_shape):\n",
    "\n",
    "        self.observations = torch.zeros(\n",
    "            num_steps + 1, num_processes, *obs_shape).to(device)\n",
    "        # *を使うと()リストの中身を取り出す\n",
    "        # obs_shape→(4,84,84)\n",
    "        # *obs_shape→ 4 84 84\n",
    "\n",
    "        self.masks = torch.ones(num_steps + 1, num_processes, 1).to(device)\n",
    "        self.rewards = torch.zeros(num_steps, num_processes, 1).to(device)\n",
    "        self.actions = torch.zeros(\n",
    "            num_steps, num_processes, 1).long().to(device)\n",
    "\n",
    "        # 割引報酬和を格納\n",
    "        self.returns = torch.zeros(num_steps + 1, num_processes, 1).to(device)\n",
    "        self.index = 0  # insertするインデックス\n",
    "\n",
    "    def insert(self, current_obs, action, reward, mask):\n",
    "        '''次のindexにtransitionを格納する'''\n",
    "        self.observations[self.index + 1].copy_(current_obs)\n",
    "        self.masks[self.index + 1].copy_(mask)\n",
    "        self.rewards[self.index].copy_(reward)\n",
    "        self.actions[self.index].copy_(action)\n",
    "\n",
    "        self.index = (self.index + 1) % NUM_ADVANCED_STEP  # インデックスの更新\n",
    "\n",
    "    def after_update(self):\n",
    "        '''Advantageするstep数が完了したら、最新のものをindex0に格納'''\n",
    "        self.observations[0].copy_(self.observations[-1])\n",
    "        self.masks[0].copy_(self.masks[-1])\n",
    "\n",
    "    def compute_returns(self, next_value):\n",
    "        '''Advantageするステップ中の各ステップの割引報酬和を計算する'''\n",
    "\n",
    "        # 注意：5step目から逆向きに計算しています\n",
    "        # 注意：5step目はAdvantage1となる。4ステップ目はAdvantage2となる。・・・\n",
    "        self.returns[-1] = next_value\n",
    "        for ad_step in reversed(range(self.rewards.size(0))):\n",
    "            self.returns[ad_step] = self.returns[ad_step + 1] * \\\n",
    "                GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A2Cのディープ・ニューラルネットワークの構築\n",
    "\n",
    "\n",
    "def init(module, gain):\n",
    "    '''層の結合パラメータを初期化する関数を定義'''\n",
    "    nn.init.orthogonal_(module.weight.data, gain=gain)\n",
    "    nn.init.constant_(module.bias.data, 0)\n",
    "    return module\n",
    "\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    '''コンボリューション層の出力画像を1次元に変換する層を定義'''\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x.view(x.size(0), -1)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, n_out):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        # 結合パラメータの初期化関数\n",
    "        def init_(module): return init(\n",
    "            module, gain=nn.init.calculate_gain('relu'))\n",
    "\n",
    "        # コンボリューション層の定義\n",
    "        self.conv = nn.Sequential(\n",
    "            # 画像サイズの変化84*84→20*20\n",
    "            init_(nn.Conv2d(NUM_STACK_FRAME, 32, kernel_size=8, stride=4)),\n",
    "            # stackするflameは4画像なのでinput=NUM_STACK_FRAME=4である、出力は32とする、\n",
    "            # sizeの計算  size = (Input_size - Kernel_size + 2*Padding_size)/ Stride_size + 1\n",
    "\n",
    "            nn.ReLU(),\n",
    "            # 画像サイズの変化20*20→9*9\n",
    "            init_(nn.Conv2d(32, 64, kernel_size=4, stride=2)),\n",
    "            nn.ReLU(),\n",
    "            init_(nn.Conv2d(64, 64, kernel_size=3, stride=1)),  # 画像サイズの変化9*9→7*7\n",
    "            nn.ReLU(),\n",
    "            Flatten(),  # 画像形式を1次元に変換\n",
    "            init_(nn.Linear(64 * 7 * 7, 512)),  # 64枚の7×7の画像を、512次元のoutputへ\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        # 結合パラメータの初期化関数\n",
    "        def init_(module): return init(module, gain=1.0)\n",
    "\n",
    "        # Criticの定義\n",
    "        self.critic = init_(nn.Linear(512, 1))  # 状態価値なので出力は1つ\n",
    "\n",
    "        # 結合パラメータの初期化関数\n",
    "        def init_(module): return init(module, gain=0.01)\n",
    "\n",
    "        # Actorの定義\n",
    "        self.actor = init_(nn.Linear(512, n_out))  # 行動を決めるので出力は行動の種類数\n",
    "\n",
    "        # ネットワークを訓練モードに設定\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''ネットワークのフォワード計算を定義します'''\n",
    "        input = x / 255.0  # 画像のピクセル値0-255を0-1に正規化する\n",
    "        conv_output = self.conv(input)  # Convolution層の計算\n",
    "        critic_output = self.critic(conv_output)  # 状態価値の計算\n",
    "        actor_output = self.actor(conv_output)  # 行動の計算\n",
    "\n",
    "        return critic_output, actor_output\n",
    "\n",
    "    def act(self, x):\n",
    "        '''状態xから行動を確率的に求めます'''\n",
    "        value, actor_output = self(x)\n",
    "        probs = F.softmax(actor_output, dim=1)    # dim=1で行動の種類方向に計算\n",
    "        action = probs.multinomial(num_samples=1)\n",
    "\n",
    "        return action\n",
    "\n",
    "    def get_value(self, x):\n",
    "        '''状態xから状態価値を求めます'''\n",
    "        value, actor_output = self(x)\n",
    "\n",
    "        return value\n",
    "\n",
    "    def evaluate_actions(self, x, actions):\n",
    "        '''状態xから状態価値、実際の行動actionsのlog確率とエントロピーを求めます'''\n",
    "        value, actor_output = self(x)\n",
    "\n",
    "        log_probs = F.log_softmax(actor_output, dim=1)  # dim=1で行動の種類方向に計算\n",
    "        action_log_probs = log_probs.gather(1, actions)  # 実際の行動のlog_probsを求める\n",
    "\n",
    "        probs = F.softmax(actor_output, dim=1)  # dim=1で行動の種類方向に計算\n",
    "        dist_entropy = -(log_probs * probs).sum(-1).mean()\n",
    "\n",
    "        return value, action_log_probs, dist_entropy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# エージェントが持つ頭脳となるクラスを定義、全エージェントで共有する\n",
    "\n",
    "\n",
    "class Brain(object):\n",
    "    def __init__(self, actor_critic):\n",
    "\n",
    "        self.actor_critic = actor_critic  # actor_criticはクラスNetのディープ・ニューラルネットワーク\n",
    "\n",
    "        # 結合パラメータをロードする場合\n",
    "        filename = 'weight_end.pth'\n",
    "        #filename = 'weight_112500.pth'\n",
    "        param = torch.load(filename, map_location='cpu')\n",
    "        self.actor_critic.load_state_dict(param)\n",
    "\n",
    "        # パラメータ更新の勾配法の設定\n",
    "        self.optimizer = optim.RMSprop(\n",
    "            actor_critic.parameters(), lr=lr, eps=eps, alpha=alpha)\n",
    "\n",
    "    def update(self, rollouts):\n",
    "        '''advanced計算した5つのstepの全てを使って更新します'''\n",
    "        obs_shape = rollouts.observations.size()[2:]  # torch.Size([4, 84, 84])\n",
    "        num_steps = NUM_ADVANCED_STEP\n",
    "        num_processes = NUM_PROCESSES\n",
    "\n",
    "        values, action_log_probs, dist_entropy = self.actor_critic.evaluate_actions(\n",
    "            rollouts.observations[:-1].view(-1, *obs_shape),\n",
    "            rollouts.actions.view(-1, 1))\n",
    "\n",
    "        # 注意：各変数のサイズ\n",
    "        # rollouts.observations[:-1].view(-1, *obs_shape) torch.Size([80, 4, 84, 84])\n",
    "        # rollouts.actions.view(-1, 1) torch.Size([80, 1])\n",
    "        # values torch.Size([80, 1])\n",
    "        # action_log_probs torch.Size([80, 1])\n",
    "        # dist_entropy torch.Size([])\n",
    "\n",
    "        values = values.view(num_steps, num_processes,\n",
    "                             1)  # torch.Size([5, 16, 1])\n",
    "        action_log_probs = action_log_probs.view(num_steps, num_processes, 1)\n",
    "\n",
    "        advantages = rollouts.returns[:-1] - values  # torch.Size([5, 16, 1])\n",
    "        value_loss = advantages.pow(2).mean()\n",
    "\n",
    "        action_gain = (advantages.detach() * action_log_probs).mean()\n",
    "        # detachしてadvantagesを定数として扱う\n",
    "\n",
    "        total_loss = (value_loss * value_loss_coef -\n",
    "                      action_gain - dist_entropy * entropy_coef)\n",
    "\n",
    "        self.optimizer.zero_grad()  # 勾配をリセット\n",
    "        total_loss.backward()  # バックプロパゲーションを計算\n",
    "        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)\n",
    "        #  一気に結合パラメータが変化しすぎないように、勾配の大きさは最大0.5までにする\n",
    "\n",
    "        self.optimizer.step()  # 結合パラメータを更新\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Breakoutを実行する環境のクラス\n",
    "\n",
    "NUM_PROCESSES = 1\n",
    "\n",
    "\n",
    "class Environment:\n",
    "    def run(self):\n",
    "\n",
    "        # seedの設定\n",
    "        seed_num = 1\n",
    "        torch.manual_seed(seed_num)\n",
    "        if use_cuda:\n",
    "            torch.cuda.manual_seed(seed_num)\n",
    "\n",
    "        # 実行環境を構築\n",
    "        torch.set_num_threads(seed_num)\n",
    "        envs = [make_env(ENV_NAME, seed_num, i) for i in range(NUM_PROCESSES)]\n",
    "        envs = SubprocVecEnv(envs)  # マルチプロセスの実行環境にする\n",
    "\n",
    "        # 全エージェントが共有して持つ頭脳Brainを生成\n",
    "        n_out = envs.action_space.n  # 行動の種類は4\n",
    "        actor_critic = Net(n_out).to(device)  # GPUへ\n",
    "        global_brain = Brain(actor_critic)\n",
    "\n",
    "        # 格納用変数の生成\n",
    "        obs_shape = envs.observation_space.shape  # (1, 84, 84)\n",
    "        obs_shape = (obs_shape[0] * NUM_STACK_FRAME,\n",
    "                     *obs_shape[1:])  # (4, 84, 84)\n",
    "        # torch.Size([16, 4, 84, 84])\n",
    "        current_obs = torch.zeros(NUM_PROCESSES, *obs_shape).to(device)\n",
    "        rollouts = RolloutStorage(\n",
    "            NUM_ADVANCED_STEP, NUM_PROCESSES, obs_shape)  # rolloutsのオブジェクト\n",
    "        episode_rewards = torch.zeros([NUM_PROCESSES, 1])  # 現在の試行の報酬を保持\n",
    "        final_rewards = torch.zeros([NUM_PROCESSES, 1])  # 最後の試行の報酬和を保持\n",
    "\n",
    "        # 初期状態の開始\n",
    "        obs = envs.reset()\n",
    "        obs = torch.from_numpy(obs).float()  # torch.Size([16, 1, 84, 84])\n",
    "        current_obs[:, -1:] = obs  # flameの4番目に最新のobsを格納\n",
    "\n",
    "        # advanced学習用のオブジェクトrolloutsの状態の1つ目に、現在の状態を保存\n",
    "        rollouts.observations[0].copy_(current_obs)\n",
    "\n",
    "        # 描画用の環境（再生用に追加）\n",
    "        env_play = make_env_play(ENV_NAME, seed_num, 0)\n",
    "        obs_play = env_play.reset()\n",
    "\n",
    "        # 動画にするために画像を格納する変数（再生用に追加）\n",
    "        frames = []\n",
    "        main_end = False\n",
    "\n",
    "        # 実行ループ\n",
    "        for j in tqdm(range(NUM_UPDATES)):\n",
    "\n",
    "            # 報酬が基準を超えたら終わりにする（再生用に追加）\n",
    "            if main_end:\n",
    "                break\n",
    "\n",
    "            # advanced学習するstep数ごとに計算\n",
    "            for step in range(NUM_ADVANCED_STEP):\n",
    "\n",
    "                # 行動を求める\n",
    "                with torch.no_grad():\n",
    "                    action = actor_critic.act(rollouts.observations[step])\n",
    "\n",
    "                cpu_actions = action.squeeze(1).cpu().numpy()  # tensorをNumPyに\n",
    "\n",
    "                # 1stepの並列実行、なお返り値のobsのsizeは(16, 1, 84, 84)\n",
    "                obs, reward, done, info = envs.step(cpu_actions)\n",
    "\n",
    "                # 報酬をtensorに変換し、試行の総報酬に足す\n",
    "                # sizeが(16,)になっているのを(16, 1)に変換\n",
    "                reward = np.expand_dims(np.stack(reward), 1)\n",
    "                reward = torch.from_numpy(reward).float()\n",
    "                episode_rewards += reward\n",
    "\n",
    "                # 各実行環境それぞれについて、doneならmaskは0に、継続中ならmaskは1にする\n",
    "                masks = torch.FloatTensor(\n",
    "                    [[0.0] if done_ else [1.0] for done_ in done])\n",
    "\n",
    "                # 最後の試行の総報酬を更新する\n",
    "                final_rewards *= masks  # 継続中の場合は1をかけ算してそのまま、done時には0を掛けてリセット\n",
    "                # 継続中は0を足す、done時にはepisode_rewardsを足す\n",
    "                final_rewards += (1 - masks) * episode_rewards\n",
    "\n",
    "                # 画像を取得する(再生用に追加）\n",
    "                obs_play, reward_play, _, _ = env_play.step(cpu_actions[0])\n",
    "                frames.append(obs_play)  # 変換した画像を保存\n",
    "                if done[0]:  # 並列環境の1つ目が終了した場合\n",
    "                    print(episode_rewards[0][0].numpy())  # 報酬\n",
    "\n",
    "                    # 報酬が300を超えたら終わりにする\n",
    "                    if (episode_rewards[0][0].numpy()) > 300:\n",
    "                        main_end = True\n",
    "                        break\n",
    "                    else:\n",
    "                        obs_view = env_play.reset()\n",
    "                        frames = []  # 保存した画像をリセット\n",
    "\n",
    "                # 試行の総報酬を更新する\n",
    "                episode_rewards *= masks  # 継続中のmaskは1なのでそのまま、doneの場合は0に\n",
    "\n",
    "                # masksをGPUへ\n",
    "                masks = masks.to(device)\n",
    "\n",
    "                # 現在の状態をdone時には全部0にする\n",
    "                # maskのサイズをtorch.Size([16, 1])→torch.Size([16, 1, 1 ,1])へ変換して、かけ算\n",
    "                current_obs *= masks.unsqueeze(2).unsqueeze(2)\n",
    "\n",
    "                # frameをstackする\n",
    "                # torch.Size([16, 1, 84, 84])\n",
    "                obs = torch.from_numpy(obs).float()\n",
    "                current_obs[:, :-1] = current_obs[:, 1:]  # 0～2番目に1～3番目を上書き\n",
    "                current_obs[:, -1:] = obs  # 4番目に最新のobsを格納\n",
    "\n",
    "                # メモリオブジェクトに今stepのtransitionを挿入\n",
    "                rollouts.insert(current_obs, action.data, reward, masks)\n",
    "\n",
    "            # advancedのfor loop終了\n",
    "\n",
    "            # advancedした最終stepの状態から予想する状態価値を計算\n",
    "            with torch.no_grad():\n",
    "                next_value = actor_critic.get_value(\n",
    "                    rollouts.observations[-1]).detach()\n",
    "\n",
    "            # 全stepの割引報酬和を計算して、rolloutsの変数returnsを更新\n",
    "            rollouts.compute_returns(next_value)\n",
    "\n",
    "            # ネットワークとrolloutの更新\n",
    "            # global_brain.update(rollouts)\n",
    "            rollouts.after_update()\n",
    "\n",
    "        # 実行ループ終わり\n",
    "        display_frames_as_gif(frames)  # 動画の保存と再生\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 実行\n",
    "breakout_env = Environment()\n",
    "frames = breakout_env.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
